Usage

svg_heatmap can be used as a drop-in replacement for seaborn.heatmap, except that it does not support the following seaborn.heatmap parameters:

  • center
  • annot
    • fmt
  • linewidths, linecolor
  • square
  • mask
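Code that switches from seaborn.heatmap to svg_heatmap may want to check its keyword arguments against the list above first. The sketch below is a hypothetical helper (`unsupported_heatmap_kwargs` is not part of svg_heatmap) that only compares keyword names with that list; how svg_heatmap itself reacts to these arguments is not covered here.

```python
# Hypothetical helper, not part of svg_heatmap: the set below mirrors the
# list of seaborn.heatmap parameters that svg_heatmap does not support.
UNSUPPORTED_KWARGS = frozenset(
    {"center", "annot", "fmt", "linewidths", "linecolor", "square", "mask"}
)


def unsupported_heatmap_kwargs(**kwargs):
    """Return the sorted names of keyword arguments svg_heatmap lacks."""
    # Iterating over kwargs yields its keys, so this intersects names only
    return sorted(UNSUPPORTED_KWARGS.intersection(kwargs))


# cmap and cbar are supported; annot and square are not
print(unsupported_heatmap_kwargs(cmap="viridis", cbar=False, annot=True, square=True))
# prints ['annot', 'square']
```

A caller could treat a non-empty result as a signal to fall back to seaborn.heatmap for that particular plot.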
Please note that svg_heatmap does not use any internal seaborn functions, nor does it use matplotlib.pyplot.pcolormesh to draw the color mesh. Its output will therefore look slightly different from the corresponding seaborn plot.
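Drawing a heatmap without pcolormesh can work, in principle, by emitting one SVG `<rect>` per cell and coloring it from a normalized value. The sketch below illustrates that idea only; it is not svg_heatmap's code. `svg_heatmap_sketch` and `lerp_color` are hypothetical names, the normalization is a plain linear one between the data minimum and maximum, and to stay dependency-free it interpolates between two fixed RGB colors (approximately the end colors of viridis) instead of using a matplotlib colormap.

```python
def lerp_color(c0, c1, t):
    """Linearly interpolate between two RGB tuples, with t in [0, 1]."""
    return tuple(round(a + (b - a) * t) for a, b in zip(c0, c1))


def svg_heatmap_sketch(rows, cell=20, low=(68, 1, 84), high=(253, 231, 37)):
    """Render a 2-D list of numbers as an SVG string of <rect> cells.

    Hypothetical illustration only: one <rect> per cell, colored by a
    linear normalization of its value between the data minimum and maximum.
    """
    values = [v for row in rows for v in row]
    vmin, vmax = min(values), max(values)
    span = (vmax - vmin) or 1  # avoid division by zero for constant data
    n_rows, n_cols = len(rows), len(rows[0])
    parts = ['<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}">'
             .format(n_cols * cell, n_rows * cell)]
    for i, row in enumerate(rows):
        for j, v in enumerate(row):
            # Row i goes down the y axis, column j across the x axis
            r, g, b = lerp_color(low, high, (v - vmin) / span)
            parts.append('<rect x="{}" y="{}" width="{}" height="{}" fill="rgb({},{},{})"/>'
                         .format(j * cell, i * cell, cell, cell, r, g, b))
    parts.append('</svg>')
    return ''.join(parts)


svg = svg_heatmap_sketch([[0, 1], [2, 3]])
print(svg.count('<rect'))  # prints 4
```

In this sketch the SVG grows by one `<rect>` element per cell, which is worth keeping in mind when reading the kB sizes reported by compare_plots.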

Imports

In [1]:
import sys
import binascii
from io import BytesIO

import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from ipywidgets import HTML

from svg_heatmap import heatmap

np.random.seed(0)  # make the random example data reproducible
sns.set()          # apply seaborn's default theme to the reference plots

Comparing plots

In [2]:
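def compare_plots(data, svg_kwargs=None, sns_kwargs=None, **kwargs):
    """Render `data` with svg_heatmap and with seaborn (as SVG and as PNG)
    and show all three, each labelled with its size in kB."""
    # Avoid mutable default arguments
    svg_kwargs = svg_kwargs or {}
    sns_kwargs = sns_kwargs or {}

    # svg_heatmap returns the SVG markup as a string
    svg_plot = heatmap(data, **svg_kwargs, **kwargs)

    # Draw the reference plot with seaborn
    fig = plt.figure()
    sns.heatmap(data, **sns_kwargs, **kwargs)
    plt.tight_layout()

    # Embed the seaborn plot as a base64-encoded PNG
    with BytesIO() as buf:
        fig.savefig(buf, format='png')
        png_data = binascii.b2a_base64(buf.getvalue(), newline=False).decode()
    sns_png_plot = '<img src="data:image/png;base64,{}">'.format(png_data)

    # Export the same seaborn figure as inline SVG
    with BytesIO() as buf:
        fig.savefig(buf, format='svg')
        sns_svg_plot = buf.getvalue().decode()

    plt.close(fig)

    # Size of each snippet in kB, counted as UTF-8 bytes
    # (sys.getsizeof would also count Python's object overhead)
    svg_size, sns_png_size, sns_svg_size = [
        '{:.1f}kB'.format(len(plot.encode()) / 1024)
        for plot in (svg_plot, sns_png_plot, sns_svg_plot)]

    return HTML('svg {}<br>'.format(svg_size) + svg_plot
                + '<br>sns svg {}<br>'.format(sns_svg_size) + sns_svg_plot
                + '<br>sns png {}<br>'.format(sns_png_size) + sns_png_plot)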
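compare_plots embeds the seaborn PNG as a base64 data URI. Base64 encodes every 3 bytes as 4 ASCII characters, so the embedded PNG is roughly a third larger than the raw PNG file, which is part of the reported "sns png" size. The stdlib sketch below shows that arithmetic; `to_data_uri` and `size_kb` are hypothetical helper names, and the input bytes are stand-in data, not a real PNG.

```python
import base64
import math


def to_data_uri(raw, mime="image/png"):
    """Wrap raw bytes in an <img> tag using a base64 data URI."""
    encoded = base64.b64encode(raw).decode("ascii")
    return '<img src="data:{};base64,{}">'.format(mime, encoded)


def size_kb(text):
    """Size of a string in kB, counted as UTF-8 bytes."""
    return len(text.encode("utf-8")) / 1024


raw = bytes(range(256)) * 12  # 3072 bytes of stand-in data, not a real PNG
html = to_data_uri(raw)
# Base64 output length is 4 * ceil(n / 3)
print(4 * math.ceil(len(raw) / 3))  # prints 4096
print(round(size_kb(html), 2))
```

The 3072 input bytes become 4096 base64 characters plus the short `<img>` wrapper.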

ndarray data

In [3]:
compare_plots(np.random.rand(10, 12), cmap='viridis')

DataFrame data

In [4]:
flights = sns.load_dataset("flights")
# pandas >= 2.0 requires keyword arguments for DataFrame.pivot
flights = flights.pivot(index="month", columns="year", values="passengers")
compare_plots(flights, cmap='magma')
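The flights dataset is in long format, with one row per (year, month) pair; pivot reshapes it into a month × year table, and seaborn.heatmap uses a DataFrame's index and columns as the y- and x-axis labels. The stdlib sketch below mimics that reshaping on a few made-up records (not the real flights values); `pivot_records` is a hypothetical name, and its label ordering (first-seen) may differ from what pandas produces.

```python
def pivot_records(records):
    """Reshape (row, col, value) records into (row_labels, col_labels, matrix).

    Labels keep their first-seen order; missing cells become None.
    """
    row_labels, col_labels, cells = [], [], {}
    for row, col, value in records:
        if row not in row_labels:
            row_labels.append(row)
        if col not in col_labels:
            col_labels.append(col)
        cells[row, col] = value
    # One inner list per row label, one entry per column label
    matrix = [[cells.get((r, c)) for c in col_labels] for r in row_labels]
    return row_labels, col_labels, matrix


# Made-up long-format records: (month, year, passengers)
records = [("Jan", 1949, 10), ("Feb", 1949, 20), ("Jan", 1950, 30), ("Feb", 1950, 40)]
rows, cols, matrix = pivot_records(records)
print(rows, cols, matrix)  # prints ['Jan', 'Feb'] [1949, 1950] [[10, 30], [20, 40]]
```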

No cbar

In [5]:
compare_plots(flights, cmap='magma', cbar=False)

Log scaling

In [6]:
data_w_outliers = np.random.rand(10, 12)
# Add 5 * max to row 2, columns 3-5, creating outliers far above the rest
data_w_outliers[2:3, 3:6] += 5 * data_w_outliers.max()
In [7]:
from matplotlib.colors import LogNorm
compare_plots(data_w_outliers, cmap='magma', svg_kwargs=dict(log_scaling=True), 
              sns_kwargs=dict(norm=LogNorm(vmin=data_w_outliers.min(), vmax=data_w_outliers.max())))
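With the outlier block near 6 and ordinary values below 1, a linear color scale squeezes almost all ordinary cells into the bottom sixth of the colormap. The sketch below compares linear normalization with a logarithmic one of the form matplotlib's LogNorm uses, (log x - log vmin) / (log vmax - log vmin). The range is illustrative, and whether svg_heatmap's `log_scaling=True` uses exactly this mapping is an assumption.

```python
import math


def linear_norm(x, vmin, vmax):
    """Map x from [vmin, vmax] to [0, 1] on a linear scale."""
    return (x - vmin) / (vmax - vmin)


def log_norm(x, vmin, vmax):
    """Map x from [vmin, vmax] to [0, 1] on a log scale (requires 0 < vmin)."""
    return (math.log(x) - math.log(vmin)) / (math.log(vmax) - math.log(vmin))


vmin, vmax = 0.01, 6.0  # illustrative range: ordinary data plus outliers near 6
for x in (0.01, 0.1, 1.0, 6.0):
    print(x, round(linear_norm(x, vmin, vmax), 3), round(log_norm(x, vmin, vmax), 3))
```

An ordinary value of 1.0 lands at about 0.17 on the linear scale but at about 0.72 on the log scale, so the ordinary cells spread over most of the colormap instead of sharing its darkest end.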